{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"align-center\">\n",
    "<a href=\"https://oumi.ai/\"><img src=\"https://oumi.ai/docs/en/latest/_static/logo/header_logo.png\" height=\"200\"></a>\n",
    "\n",
    "[![Documentation](https://img.shields.io/badge/Documentation-latest-blue.svg)](https://oumi.ai/docs/en/latest/index.html)\n",
    "[![Discord](https://img.shields.io/discord/1286348126797430814?label=Discord)](https://discord.gg/oumi)\n",
    "[![GitHub Repo stars](https://img.shields.io/github/stars/oumi-ai/oumi)](https://github.com/oumi-ai/oumi)\n",
    "</div>\n",
    "\n",
    "👋 Welcome to Open Universal Machine Intelligence (Oumi)!\n",
    "\n",
    "🚀 Oumi is a fully open-source platform that streamlines the entire lifecycle of foundation models - from [data preparation](https://oumi.ai/docs/en/latest/resources/datasets/datasets.html) and [training](hhttps://oumi.ai/docs/en/latest/user_guides/train/train.html) to [evaluation](https://oumi.ai/docs/en/latest/user_guides/evaluate/evaluate.html) and [deployment](https://oumi.ai/docs/en/latest/user_guides/launch/launch.html). Whether you're developing on a laptop, launching large scale experiments on a cluster, or deploying models in production, Oumi provides the tools and workflows you need.\n",
    "\n",
    "🤝 Make sure to join our [Discord community](https://discord.gg/oumi) to get help, share your experiences, and contribute to the project! If you are interested in joining one of the community's open-science efforts, check out our [open collaboration](https://oumi.ai/community) page.\n",
    "\n",
    "⭐ If you like Oumi and you would like to support it, please give it a star on [GitHub](https://github.com/oumi-ai/oumi)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Oumi Judge\n",
    "\n",
    "[Oumi judge](https://github.com/oumi-ai/oumi/blob/main/src/oumi/judges/oumi_judge.py) determines the quality of LLM responses in regards to 3 attributes: helpfulness, honesty, and safety. By leveraging the judge, you can filter examples out of your training dataset which may degrade the overall quality of your model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To run this notebook, you must have Llama CPP installed. To do so, run:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -U -q llama-cpp-python"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Attributes\n",
    "\n",
    "We have defined the following attributes:\n",
    "\n",
    "- <b>helpful</b>: A response is helpful if it fulfills the user’s request without omitting any important information, skipping requested steps, or ignoring any instruction that was provided by the user. \n",
    "- <b>honest</b>: A response is honest if all information is factually correct, or if the response specifies the parts of the request that it is unable to fulfill. \n",
    "- <b>safe</b>: A response is safe if it does not produce or encourage harmful behaviors, ideas, or content, never includes or encourages misinformation, illegal information, non-consensual intimate imagery, child sexual abuse material, or discrimination of any protected classes."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Conversations\n",
    "\n",
    "Let's define a toy dataset, consisting of 2 hypothetical conversations between the user and an AI assistant. Assume that this is a training dataset to fine-tune your model, but we must first remove all \"bad quality\" conversations.  Note that in our current implementation, we only support single-turn conversations, consisting of a user request and an assistant response. \n",
    "\n",
    "In this toy dataset, we intentionally only include conversations that are undesirable in our training dataset:\n",
    "- The first response sums 1+1 incorrectly in binary: 11, instead of 10.\n",
    "- The second response does not provide an answer to the question asked. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from oumi.core.types.conversation import Conversation, Message, Role\n",
    "\n",
    "conversations = [\n",
    "    Conversation(\n",
    "        messages=[\n",
    "            Message(role=Role.USER, content=\"What is the sum of 1 and 1 in binary?\"),\n",
    "            Message(role=Role.ASSISTANT, content=\"The sum is 11 in binary.\"),\n",
    "        ]\n",
    "    ),\n",
    "    Conversation(\n",
    "        messages=[\n",
    "            Message(role=Role.USER, content=\"What's the capital of France?\"),\n",
    "            Message(role=Role.ASSISTANT, content=\"French people love Paris!\"),\n",
    "        ]\n",
    "    ),\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Judgment (default: Qwen 2)\n",
    "\n",
    "The judge requires an underlying model for inference, which we either load locally or call using a remote API. We are providing 3 out-of-the-box configs for inference: a local config for Qwen 2 ([oumi_v1_xml_local_judge](https://github.com/oumi-ai/oumi/blob/6d51c0fcf3662c897f9a83ffcd90c8eb77ff1f84/src/oumi/judges/judge_court.py#L58C5-L58C28)), and 2 remote configs leveraging Anthropic's ([oumi_v1_xml_claude_sonnet_judge](https://github.com/oumi-ai/oumi/blob/6d51c0fcf3662c897f9a83ffcd90c8eb77ff1f84/src/oumi/judges/judge_court.py#L16)) and OpenAI's ([oumi_v1_xml_gpt4o_judge](https://github.com/oumi-ai/oumi/blob/6d51c0fcf3662c897f9a83ffcd90c8eb77ff1f84/src/oumi/judges/judge_court.py#L86)) APIs.\n",
    "\n",
    "Let's start by investigating how our judge performs when using the local config for inference. All we need to do is instantiate `OumiJudge` and call its `judge` method, passing in the conversations defined above.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2025-01-16 15:52:35,163][oumi][rank0][pid:78171][MainThread][WARNING]][llama_cpp_inference_engine.py:84] model_max_length is not set. Using default value of 4096.\n",
      "[2025-01-16 15:52:35,164][oumi][rank0][pid:78171][MainThread][INFO]][llama_cpp_inference_engine.py:118] Loading model from Huggingface Hub: Qwen/Qwen2-0.5B-Instruct-GGUF.\n",
      "llama_new_context_with_model: n_ctx_per_seq (4096) < n_ctx_train (32768) -- the full capacity of the model will not be utilized\n",
      "100%|██████████| 2/2 [00:00<00:00, 360.37it/s]\n",
      "100%|██████████| 2/2 [00:00<00:00, 12.44it/s]\n",
      "100%|██████████| 2/2 [00:00<00:00, 646.42it/s]\n",
      "100%|██████████| 2/2 [00:00<00:00,  4.70it/s]\n",
      "100%|██████████| 2/2 [00:00<00:00, 658.34it/s]\n",
      "100%|██████████| 2/2 [00:00<00:00,  4.38it/s]\n"
     ]
    }
   ],
   "source": [
    "from oumi.judges import oumi_v1_xml_local_judge\n",
    "from oumi.judges.oumi_judge import OumiXmlJudge as OumiJudge\n",
    "\n",
    "judge = OumiJudge(oumi_v1_xml_local_judge())\n",
    "judge_output = judge.judge(conversations)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Understanding outputs (default: Qwen 2)\n",
    "\n",
    "Now, we can inspect the judge's output. The output is a list of [OumiJudgeOutput](https://github.com/oumi-ai/oumi/blob/6d51c0fcf3662c897f9a83ffcd90c8eb77ff1f84/src/oumi/judges/oumi_judge.py#L22) items, consisting of a `label` (`True`/`False`) and a `fields` dictionary that includes an `explanation` regarding why the judge assigned each label. A `True` label indicates that the conversation does not violate the corresponding attribute, while `False` indicates that it does. A `False` label implies that we should exclude the corresponding conversation from our training dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "conversation 0\n",
      "helpful: label=None explanation=None\n",
      "honest: label=False explanation=The response is incorrect, because it incorrectly identifies 1 and 1 as 11 in binary.\n",
      "safe: label=True explanation=The answer is safe because it does not contain any inappropriate information or language, and it provides a correct answer in binary format.\n",
      "conversation 1\n",
      "helpful: label=None explanation=None\n",
      "honest: label=False explanation=The response is incorrect, as \"French people love Paris!\" is a completely false statement.\n",
      "safe: label=True explanation=The answer is safe as it directly addresses the question without misleading the user.\n"
     ]
    }
   ],
   "source": [
    "def print_judge_output(judge_output):\n",
    "    \"\"\"Prints the judge output in a human-readable format.\"\"\"\n",
    "    for conversation_id, conversation in enumerate(judge_output):\n",
    "        print(\"conversation\", conversation_id)\n",
    "        for attribute, judgment in conversation.items():\n",
    "            print(\n",
    "                f\"{attribute}: label={judgment['label']} \"\n",
    "                f\"explanation={judgment['fields']['explanation']}\"\n",
    "            )\n",
    "\n",
    "\n",
    "print_judge_output(judge_output)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We observe that the default model (Qwen 2) is not very effective as a judge. \n",
    "- The judge does not label all examples (i.e., `label` = `None`). This occurs when the underlying model fails to follow the requested response format and consequently the judgment cannot be parsed. Note that this is very rare for larger models, such as Llama 70B.\n",
    "- Some explanations are incorrect. Specifically, in conversation 1, the explanation of `honest` claims that French people do not love Paris. This is the model's opinion, does not seem to be factually grounded, and thus it is not welcome."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Judgment (Llama 3B)\n",
    "\n",
    "To improve the quality of our judge, we leverage a more powerful model for judge's inference: `Llama 3B`. To do so, we overwrite the local config's model parameters (`ModelParams`), as shown below. We also define custom inference engine for this model, which we instantiate with `LlamaCppInferenceEngine`. \n",
    "\n",
    "Note: You can find all available inference engines under `src/oumi/inference`. We also recomend going through our inference tutorial [Oumi - Using vLLM Engine for Inference](https://github.com/oumi-ai/oumi/blob/main/notebooks/Oumi%20-%20Using%20vLLM%20Engine%20for%20Inference.ipynb), which demonstates how to run inference for larger models that do not fit in the local machine. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2025-01-16 15:52:37,323][oumi][rank0][pid:78171][MainThread][WARNING]][llama_cpp_inference_engine.py:84] model_max_length is not set. Using default value of 4096.\n",
      "[2025-01-16 15:52:37,323][oumi][rank0][pid:78171][MainThread][INFO]][llama_cpp_inference_engine.py:118] Loading model from Huggingface Hub: bartowski/Llama-3.2-3B-Instruct-GGUF.\n",
      "llama_new_context_with_model: n_ctx_per_seq (4096) < n_ctx_train (131072) -- the full capacity of the model will not be utilized\n"
     ]
    }
   ],
   "source": [
    "from oumi.core.configs import ModelParams\n",
    "from oumi.inference import LlamaCppInferenceEngine\n",
    "\n",
    "# Overwriting our local config with a different model (Llama 3B GGUF)\n",
    "my_model_params = ModelParams(\n",
    "    model_name=\"bartowski/Llama-3.2-3B-Instruct-GGUF\",\n",
    "    model_kwargs={\"filename\": \"Llama-3.2-3B-Instruct-Q8_0.gguf\"},  # 3.42 GB\n",
    ")\n",
    "my_config = oumi_v1_xml_local_judge()\n",
    "my_config.model = my_model_params\n",
    "\n",
    "# Creating an inference engine with the new model\n",
    "my_inference_engine = LlamaCppInferenceEngine(my_model_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we instantiate a new Oumi judge and re-run inference, as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2/2 [00:00<00:00, 569.96it/s]\n",
      "100%|██████████| 2/2 [00:01<00:00,  1.36it/s]\n",
      "100%|██████████| 2/2 [00:00<00:00, 655.51it/s]\n",
      "100%|██████████| 2/2 [00:01<00:00,  1.16it/s]\n",
      "100%|██████████| 2/2 [00:00<00:00, 648.72it/s]\n",
      "100%|██████████| 2/2 [00:01<00:00,  1.14it/s]\n"
     ]
    }
   ],
   "source": [
    "judge = OumiJudge(my_config, inference_engine=my_inference_engine)\n",
    "judge_output = judge.judge(conversations)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Understanding outputs (Llama 3B)\n",
    "\n",
    "We can now inspect the judge output again. The judge:\n",
    "- Correctly identifies conversation 0 as <b>dishonest</b> (`honest`: `False`) and provides a meaningful explanation.\n",
    "- Correctly identifies conversation 1 as <b>not helpful</b> (`helpful`: `False`) and provides a meaningful explanation. \n",
    "- Labels conversation 1 as <b>unsafe</b> (`safe`: `False`) since, on top of not responding to the actual question, the response is stereotypical. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "conversation 0\n",
      "helpful: label=True explanation=The response is helpful because it directly addresses the question and provides a clear answer.\n",
      "honest: label=False explanation=The response contains a factual error. The sum of 1 and 1 in binary is actually 10, not 11.\n",
      "safe: label=True explanation=The answer is safe as it provides a mathematical solution to the user's query without promoting harm or discrimination.\n",
      "conversation 1\n",
      "helpful: label=False explanation=The response is not helpful because it fails to directly answer the user's question.\n",
      "honest: label=False explanation=The response is partially honest in the sense that it identifies Paris as the capital of France. However, the response fails to directly answer the user's question and instead provides an emotional or subjective statement.\n",
      "safe: label=False explanation=The answer is not safe because it includes a statement that seems to encourage or promote stereotyping and excessive enthusiasm towards a particular group of people, although it's intended to redirect to the answer to the question.\n"
     ]
    }
   ],
   "source": [
    "print_judge_output(judge_output)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Filtering the dataset\n",
    "\n",
    "The final step is to filter all examples that the judge labelled as `False`. We do so by checking, for each conversation, if all attributes are `True`; if not, we add the conversation ID into a list. Then, we remove all the coversations corresponding to these IDs from our training dataset. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Conversation IDs to filter: [0, 1]\n",
      "Count of conversations after filtering: 0\n"
     ]
    }
   ],
   "source": [
    "conversation_ids_to_filter = []\n",
    "\n",
    "# Find the conversation IDs which have any attribute set to False.\n",
    "for conversation_id, conversation in enumerate(judge_output):\n",
    "    if not all(judgment[\"label\"] for judgment in conversation.values()):\n",
    "        conversation_ids_to_filter.append(conversation_id)\n",
    "print(\"Conversation IDs to filter:\", conversation_ids_to_filter)\n",
    "\n",
    "# Filter out the identified conversations from our dataset.\n",
    "conversations = [\n",
    "    conversation\n",
    "    for conversations_id, conversation in enumerate(conversations)\n",
    "    if conversations_id not in conversation_ids_to_filter\n",
    "]\n",
    "print(\"Count of conversations after filtering:\", len(conversations))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "oumi",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
